# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""the module is used to process images."""

import copy
import math
from itertools import product

import numpy as np


def rand_init(a=0., b=1.):
    """Draw a uniform random float from the half-open interval ``[a, b)``."""
    span = b - a
    return a + span * np.random.rand()


def bbox_iou(bbox_a, bbox_b, offset=0):
    """Calculate Intersection-Over-Union(IOU) of two bounding boxes.

    Parameters
    ----------
    bbox_a : numpy.ndarray
        An ndarray with shape :math:`(N, 4)`.
    bbox_b : numpy.ndarray
        An ndarray with shape :math:`(M, 4)`.
    offset : float or int, default is 0
        The ``offset`` is used to control whether the width (or height) is
        computed as (right - left + ``offset``).
        Note that the offset must be 0 for normalized bboxes, whose ranges
        are in ``[0, 1]``.

    Returns
    -------
    numpy.ndarray
        An ndarray with shape :math:`(N, M)` holding the IOU of each pair of
        bounding boxes from `bbox_a` and `bbox_b`.

    Raises
    ------
    IndexError
        If either input has fewer than 4 columns.
    """
    if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4:
        raise IndexError("Bounding boxes axis 1 must have at least length 4")

    # Pairwise intersection corners via broadcasting: (N, 1, 2) against (M, 2).
    top_left = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2])
    bottom_right = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4])

    # Zero out pairs that do not actually overlap on both axes.
    overlap_mask = (top_left < bottom_right).all(axis=2)
    inter_area = np.prod(bottom_right - top_left + offset, axis=2) * overlap_mask
    a_area = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1)
    b_area = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1)
    return inter_area / (a_area[:, None] + b_area - inter_area)


def is_iou_satisfied_constraint(box, crop_box, min_iou=None, max_iou=None):
    """Check whether the IOUs between `box` and `crop_box` satisfy a constraint.

    Args:
        box (numpy.ndarray): candidate boxes with shape (N, 4).
        crop_box (numpy.ndarray): crop region with shape (1, 4).
        min_iou (float, optional): lower IOU bound. ``None`` means no explicit
            range; the constraint is then "some IOU reaches 1.0".
        max_iou (float, optional): upper IOU bound, used together with `min_iou`.

    Returns:
        bool, True when the constraint is satisfied.
    """
    # Bugfix: the previous code passed `min_iou` to bbox_iou as its pixel
    # `offset` argument, which corrupted the area computation (and with the
    # caller's -inf sentinel made the constraint unsatisfiable). The IOU is
    # always computed with the default offset of 0.
    iou = bbox_iou(box, crop_box)
    if min_iou is None:
        # No range given: accept only if some box fully covers the crop.
        return bool(np.any(iou >= 1.0))
    # `is None` (not truthiness) so legitimate bounds like 0 or -inf are
    # treated as real bounds rather than falling into the branch above.
    return min_iou <= iou.min() and max_iou >= iou.max()


def choose_candidate_by_constraints(max_trial,
                                    input_w, input_h, image_w, image_h,
                                    jitter, box, use_constraints):
    """Randomly propose crop candidates that satisfy IOU constraints.

    Args:
        max_trial (int): number of random trials per constraint.
        input_w (int): target input width.
        input_h (int): target input height.
        image_w (int): original image width.
        image_h (int): original image height.
        jitter (float): aspect-ratio jitter factor.
        box (numpy.ndarray): ground-truth boxes with shape (N, 4).
        use_constraints (bool): when True, try a ladder of (min_iou, max_iou)
            constraints; otherwise only the unconstrained case.

    Returns:
        list, candidates as (dx, dy, nw, nh) tuples; the first entry is always
        the full-input default (0, 0, input_w, input_h).

    Raises:
        ValueError: if `box` is empty.
    """
    # Fail fast instead of raising from inside the trial loop after wasted
    # random draws (the original checked box.size on every iteration).
    if box.size <= 0:
        raise ValueError("Box size should be greater than 0, but get {}.".format(box.size))

    if use_constraints:
        # (min_iou, max_iou) ladder; None means unbounded on that side.
        constraints = ((0.1, None),
                       (0.3, None),
                       (0.5, None),
                       (0.7, None),
                       (0.9, None),
                       (None, 1),)
    else:
        constraints = ((None, None),)
    # Default candidate keeps the whole input untouched.
    candidates = [(0, 0, input_w, input_h)]
    for min_iou, max_iou in constraints:
        min_iou = -np.inf if min_iou is None else min_iou
        max_iou = np.inf if max_iou is None else max_iou

        for _ in range(max_trial):
            # Jittered aspect ratio and scale for the proposed crop.
            new_ar = float(input_w) / \
                     float(input_h) * rand_init(1 - jitter, 1 + jitter) / \
                     rand_init(1 - jitter, 1 + jitter)
            scale = rand_init(0.25, 2)

            if new_ar < 1:
                nh = int(scale * input_h)
                nw = int(nh * new_ar)
            else:
                nw = int(scale * input_w)
                nh = int(nw / new_ar)

            # Random placement of the crop inside the input canvas.
            dx = int(rand_init(0, input_w - nw))
            dy = int(rand_init(0, input_h - nh))

            # Project the ground-truth boxes into the candidate crop.
            # .copy() suffices for an ndarray; deepcopy was unnecessary.
            t_box = box.copy()
            t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx
            t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy

            crop_box = np.array((0, 0, input_w, input_h))
            if is_iou_satisfied_constraint(t_box, crop_box[np.newaxis],
                                           min_iou, max_iou):
                candidates.append((dx, dy, nw, nh))
    return candidates


def correct_bbox_by_candidates(candidates,
                               input_w, input_h, image_w, image_h,
                               flip, box, box_data, allow_outside_center):
    """Pick a crop candidate at random and remap ground-truth boxes into it.

    Args:
        candidates (list): crop candidates as (dx, dy, nw, nh) tuples, e.g.
            [(210, 291, 352, 403), ...]; index 0 is the uncropped default.
        input_w (int): target input width.
        input_h (int): target input height.
        image_w (int): original image width.
        image_h (int): original image height.
        flip (bool): whether the image was horizontally flipped.
        box (numpy.ndarray): ground-truth boxes with shape (N, 4).
        box_data (numpy.ndarray): output buffer, filled with remapped boxes.
        allow_outside_center (bool): keep boxes whose center lies outside
            the input canvas.

    Returns:
        tuple, (box_data, candidate) for the first candidate that keeps at
        least one valid box.

    Raises:
        Exception: when every candidate is rejected.
    """
    while candidates:
        # Skip index 0 (the uncropped default) while other candidates remain.
        low = 1 if len(candidates) > 1 else 0
        chosen = candidates.pop(np.random.randint(low, len(candidates)))
        dx, dy, nw, nh = chosen

        # Scale boxes from image space into the crop, then translate.
        remapped = copy.deepcopy(box)
        remapped[:, [0, 2]] = remapped[:, [0, 2]] * float(nw) / float(image_w) + dx
        remapped[:, [1, 3]] = remapped[:, [1, 3]] * float(nh) / float(image_h) + dy
        if flip:
            # Mirror x coordinates (and swap left/right edges).
            remapped[:, [0, 2]] = input_w - remapped[:, [2, 0]]

        if not allow_outside_center:
            # Drop boxes whose center falls outside the input canvas.
            cx = (remapped[:, 0] + remapped[:, 2]) / 2.
            cy = (remapped[:, 1] + remapped[:, 3]) / 2.
            remapped = remapped[np.logical_and(cx >= 0., cy >= 0.)]
            cx = (remapped[:, 0] + remapped[:, 2]) / 2.
            cy = (remapped[:, 1] + remapped[:, 3]) / 2.
            remapped = remapped[np.logical_and(cx <= input_w, cy <= input_h)]

        # Clamp coordinates into the input canvas.
        remapped[:, 0:2][remapped[:, 0:2] < 0] = 0
        remapped[:, 2][remapped[:, 2] > input_w] = input_w
        remapped[:, 3][remapped[:, 3] > input_h] = input_h

        # Discard degenerate boxes (width or height of 1 pixel or less).
        widths = remapped[:, 2] - remapped[:, 0]
        heights = remapped[:, 3] - remapped[:, 1]
        remapped = remapped[np.logical_and(widths > 1, heights > 1)]

        if remapped.shape[0] > 0:
            box_data[: len(remapped)] = remapped
            return box_data, chosen
    raise Exception("All box candidates can not satisfied re-correct bbox")


def proposal_crop_areas(max_trial, image_w, image_h, boxes):
    """Randomly propose square crop regions over the original image.

    Args:
        max_trial (int): maximum number of random trials.
        image_w (int): original image width.
        image_h (int): original image height.
        boxes (numpy.ndarray): ground-truth boxes with shape (N, 4).

    Returns:
        list, candidates as (dx, dy, nw, nh) tuples; the first entry is
        always the full image (0, 0, image_w, image_h).

    Raises:
        Exception: if `boxes` is empty.
    """
    # The whole image is always a valid fallback candidate.
    candidates = [(0, 0, image_w, image_h)]

    for _ in range(max_trial):
        # Most trials shrink the crop; roughly 20% keep the full scale.
        scale = rand_init(0.3, 1.0) if rand_init() > 0.2 else 1.0

        # Square crop sized against the shorter image side.
        side = int(scale * min(image_w, image_h))
        nw, nh = side, side

        # Random placement of the crop inside the image.
        dx = int(rand_init(0, image_w - nw))
        dy = int(rand_init(0, image_h - nh))

        if boxes.shape[0] <= 0:
            raise Exception("!!! annotation box is less than 1")

        crop_box = np.array((dx, dy, dx + nw, dy + nh))
        if is_iou_satisfied_constraint(boxes, crop_box[np.newaxis]):
            candidates.append((dx, dy, nw, nh))

        # Two accepted crops on top of the default are enough.
        if len(candidates) >= 3:
            break

    return candidates


def modify_annotation_by_proposal_crop_areas(candidates, input_w, input_h, flip, boxes,
                                             labels, landms, allow_outside_center):
    """Pick a crop candidate and remap boxes, landmarks and labels into it.

    Args:
        candidates (list): crop candidates as (dx, dy, nw, nh) tuples, e.g.
            [(210, 291, 352, 403), ...]; index 0 is the uncropped default.
        input_w (int): target input width.
        input_h (int): target input height.
        flip (bool): whether the image was horizontally flipped.
        boxes (numpy.ndarray): object boxes with shape (N, 4).
        labels (numpy.ndarray): per-box labels, e.g. shape (N,).
        landms (numpy.ndarray): 5-point landmarks with shape (N, 10).
        allow_outside_center (bool): keep boxes whose center lies outside
            the input canvas.

    Returns:
        tuple, (targets, candidate) where each targets row is laid out as
        4 box coords + 10 landmark coords (normalized) + 1 label column.

    Raises:
        Exception: when no candidate keeps at least one valid box.

    Examples:
        >>> targets_data, candidate = modify_annotation_by_proposal_crop_areas(
        candidates, input_w, input_h, flip, boxes, labels, landms,
        allow_outside_center
        )
    """
    while candidates:
        if len(candidates) > 1:
            # ignore default candidate which do not crop
            candidate = candidates.pop(np.random.randint(1, len(candidates)))
        else:
            candidate = candidates.pop(np.random.randint(0, len(candidates)))
        dx, dy, nw, nh = candidate

        # Work on copies so a rejected candidate leaves the inputs untouched.
        boxes_t = copy.deepcopy(boxes)
        landms_t = copy.deepcopy(landms)
        labels_t = copy.deepcopy(labels)
        # View the flat (N, 10) landmarks as 5 (x, y) points per box.
        landms_t = landms_t.reshape([-1, 5, 2])

        # Uniform scale mapping the crop onto the network input size.
        if nw == nh:
            scale = float(input_w) / float(nw)
        else:
            scale = float(input_w) / float(max(nh, nw))
        # Translate into crop coordinates, then scale to input size.
        boxes_t[:, [0, 2]] = (boxes_t[:, [0, 2]] - dx) * scale
        boxes_t[:, [1, 3]] = (boxes_t[:, [1, 3]] - dy) * scale
        landms_t[:, :, 0] = (landms_t[:, :, 0] - dx) * scale
        landms_t[:, :, 1] = (landms_t[:, :, 1] - dy) * scale

        if flip:
            # Mirror x coordinates, then swap landmark points 0<->1 and 3<->4
            # (presumably the left/right eye and mouth-corner pairs — confirm
            # against the dataset's landmark ordering).
            boxes_t[:, [0, 2]] = input_w - boxes_t[:, [2, 0]]
            landms_t[:, :, 0] = input_w - landms_t[:, :, 0]
            # flip landms
            landms_t_1 = landms_t[:, 1, :].copy()
            landms_t[:, 1, :] = landms_t[:, 0, :]
            landms_t[:, 0, :] = landms_t_1
            landms_t_4 = landms_t[:, 4, :].copy()
            landms_t[:, 4, :] = landms_t[:, 3, :]
            landms_t[:, 3, :] = landms_t_4

        if allow_outside_center:
            pass
        else:
            # Drop annotations whose box center lies outside the input canvas;
            # labels and landmarks are filtered with the same masks.
            mask1 = np.logical_and((boxes_t[:, 0] + boxes_t[:, 2]) / 2. >= 0.,
                                   (boxes_t[:, 1] + boxes_t[:, 3]) / 2. >= 0.)
            boxes_t = boxes_t[mask1]
            landms_t = landms_t[mask1]
            labels_t = labels_t[mask1]

            mask2 = np.logical_and((boxes_t[:, 0] + boxes_t[:, 2]) / 2. <= input_w,
                                   (boxes_t[:, 1] + boxes_t[:, 3]) / 2. <= input_h)
            boxes_t = boxes_t[mask2]
            landms_t = landms_t[mask2]
            labels_t = labels_t[mask2]

        # recorrect x, y for case x,y < 0 reset to zero, after dx and dy, some box can smaller than zero
        boxes_t[:, 0:2][boxes_t[:, 0:2] < 0] = 0
        # recorrect w,h not higher than input size
        boxes_t[:, 2][boxes_t[:, 2] > input_w] = input_w
        boxes_t[:, 3][boxes_t[:, 3] > input_h] = input_h
        box_w = boxes_t[:, 2] - boxes_t[:, 0]
        box_h = boxes_t[:, 3] - boxes_t[:, 1]
        # discard invalid box: w or h smaller than 1 pixel
        mask3 = np.logical_and(box_w > 1, box_h > 1)
        boxes_t = boxes_t[mask3]
        landms_t = landms_t[mask3]
        labels_t = labels_t[mask3]

        # normalize coordinates into [0, 1] relative to the input size
        boxes_t[:, [0, 2]] /= input_w
        boxes_t[:, [1, 3]] /= input_h
        landms_t[:, :, 0] /= input_w
        landms_t[:, :, 1] /= input_h

        # Pack everything back into flat rows: 4 + 10 + 1 columns.
        landms_t = landms_t.reshape([-1, 10])
        labels_t = np.expand_dims(labels_t, 1)
        targets_t = np.hstack((boxes_t, landms_t, labels_t))

        if boxes_t.shape[0] > 0:
            return targets_t, candidate

    raise Exception('all candidates can not satisfied re-correct bbox')

def center_point_2_box(boxes):
    """Convert (cx, cy, w, h) boxes to (xmin, ymin, xmax, ymax) corners."""
    centers = boxes[:, 0:2]
    half_sizes = boxes[:, 2:4] / 2
    return np.concatenate((centers - half_sizes, centers + half_sizes), axis=1)

def compute_intersect(a, b):
    """Pairwise intersection areas of corner-format boxes `a` (N,4) and `b` (M,4)."""
    num_a = a.shape[0]
    num_b = b.shape[0]
    # Broadcast so every box in `a` is paired with every box in `b`.
    upper = np.minimum(
        np.broadcast_to(np.expand_dims(a[:, 2:4], 1), [num_a, num_b, 2]),
        np.broadcast_to(np.expand_dims(b[:, 2:4], 0), [num_a, num_b, 2]))
    lower = np.maximum(
        np.broadcast_to(np.expand_dims(a[:, 0:2], 1), [num_a, num_b, 2]),
        np.broadcast_to(np.expand_dims(b[:, 0:2], 0), [num_a, num_b, 2]))
    # Negative extents mean no overlap; clamp them to zero.
    extents = np.maximum(upper - lower, np.zeros_like(upper - lower))
    return extents[:, :, 0] * extents[:, :, 1]

def compute_overlaps(a, b):
    """Pairwise IOU between corner-format boxes `a` (N,4) and `b` (M,4)."""
    inter = compute_intersect(a, b)
    pair_shape = np.shape(inter)
    # Individual areas, broadcast to the (N, M) pair grid.
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    area_a = np.broadcast_to(np.expand_dims(area_a, 1), pair_shape)
    area_b = np.broadcast_to(np.expand_dims(area_b, 0), pair_shape)
    union = area_a + area_b - inter
    return inter / union


def match(threshold, boxes, priors, var, labels, landms):
    """Match priors to ground-truth boxes and encode regression targets.

    Parameters
    ----------
    threshold : float
        IOU threshold below which a prior is labeled background (conf 0).
    boxes : numpy.ndarray
        Ground-truth boxes, shape :math:`(num_boxes, 4)` in corner format.
    priors : numpy.ndarray
        Prior boxes, shape :math:`(M, 4)` in center-size format (cx, cy, w, h).
    var : list
        Encoding variances, e.g. [0.1, 0.2].
    labels : numpy.ndarray
        Per-box class labels, shape :math:`(num_boxes,)` (or (num_boxes, 1);
        `conf` mirrors labels' trailing dimensions).
    landms : numpy.ndarray
        Landmarks, shape :math:`(num_boxes, 10)`.

    Returns
    -------
    numpy.ndarray
        loc: (M, 4) encoded box offsets.
        conf: (M,) int32 labels per prior, 0 for background.
        landm: (M, 10) encoded landmark offsets.

    """
    # Pairwise IOU between GT boxes and priors (converted to corner format).
    overlaps = compute_overlaps(boxes, center_point_2_box(priors))

    # For each GT box: the overlap with / index of its best-matching prior.
    best_prior_overlap = overlaps.max(1, keepdims=True)
    best_prior_idx = np.argsort(-overlaps, axis=1)[:, 0:1]

    # Ignore GT boxes whose best prior overlaps less than 0.2.
    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
    if best_prior_idx_filter.shape[0] <= 0:
        # No usable ground truth: every prior is background.
        loc = np.zeros((priors.shape[0], 4), dtype=np.float32)
        conf = np.zeros((priors.shape[0],), dtype=np.int32)
        landm = np.zeros((priors.shape[0], 10), dtype=np.float32)
        return loc, conf, landm

    # For each prior: the overlap with / index of its best-matching GT box.
    best_truth_overlap = overlaps.max(0, keepdims=True)
    best_truth_idx = np.argsort(-overlaps, axis=0)[:1, :]

    best_truth_idx = best_truth_idx.squeeze(0)
    best_truth_overlap = best_truth_overlap.squeeze(0)
    best_prior_idx = best_prior_idx.squeeze(1)
    best_prior_idx_filter = best_prior_idx_filter.squeeze(1)
    # Force-keep the best prior of each valid GT: 2 exceeds any IOU threshold.
    best_truth_overlap[best_prior_idx_filter] = 2

    # Ensure each GT's best prior is assigned to that GT (overrides argmax).
    for j in range(best_prior_idx.shape[0]):
        best_truth_idx[best_prior_idx[j]] = j

    # GT box matched to each prior.
    matches = boxes[best_truth_idx]

    # encode boxes: center offsets and log-scaled sizes, variance-normalized
    offset_cxcy = (matches[:, 0:2] + matches[:, 2:4]) / 2 - priors[:, 0:2]
    offset_cxcy /= (var[0] * priors[:, 2:4])
    wh = (matches[:, 2:4] - matches[:, 0:2]) / priors[:, 2:4]
    wh[wh == 0] = 1e-12  # avoid log(0) for degenerate matches
    wh = np.log(wh) / var[1]
    loc = np.concatenate([offset_cxcy, wh], axis=1)

    conf = labels[best_truth_idx]
    conf[best_truth_overlap < threshold] = 0  # low-overlap priors -> background

    matches_landm = landms[best_truth_idx]

    # encode landms: per-point offsets from the prior center, variance-normalized
    matched = np.reshape(matches_landm, [-1, 5, 2])
    priors = np.broadcast_to(np.expand_dims(priors, 1), [priors.shape[0], 5, 4])
    offset_cxcy = matched[:, :, 0:2] - priors[:, :, 0:2]
    offset_cxcy /= (priors[:, :, 2:4] * var[0])
    landm = np.reshape(offset_cxcy, [-1, 10])

    return loc, np.array(conf, dtype=np.int32), landm


def prior_box(image_sizes, min_sizes, steps, clip=False):
    """Generate prior (anchor) boxes in normalized center-size format.

    Note: requires ``math`` and ``itertools.product``, which the original
    file never imported (a latent NameError); they are now imported at the
    top of the module.

    Args:
        image_sizes (tuple): input image size as (height, width).
        min_sizes (list): per-feature-map lists of anchor side lengths (pixels).
        steps (tuple): down-sampling stride of each feature map.
        clip (bool): when True, clamp coordinates into [0, 1].

    Returns:
        numpy.ndarray, float32 anchors of shape (num_anchors, 4) laid out as
        (cx, cy, w, h), all normalized by the image size.

    Examples:
        >>> priors = prior_box(image_sizes, min_sizes, steps)
    """
    # Spatial size of each feature map after its stride.
    feature_maps = [
        [math.ceil(image_sizes[0] / step), math.ceil(image_sizes[1] / step)]
        for step in steps]

    anchors = []
    for k, f in enumerate(feature_maps):
        # Walk every cell of the k-th feature map.
        for i, j in product(range(f[0]), range(f[1])):
            for min_size in min_sizes[k]:
                s_kx = min_size / image_sizes[1]
                s_ky = min_size / image_sizes[0]
                # Anchor centered in the middle of the receptive cell.
                cx = (j + 0.5) * steps[k] / image_sizes[1]
                cy = (i + 0.5) * steps[k] / image_sizes[0]
                anchors += [cx, cy, s_kx, s_ky]

    output = np.asarray(anchors).reshape([-1, 4]).astype(np.float32)

    if clip:
        output = np.clip(output, 0, 1)

    return output
